# -*- coding: utf-8 -*-
import torch
from torch import nn
import copy
import numpy as np
import math
from math import cos, sin, pi
import sympy
import copy
import random
import matplotlib.pyplot as plt
from torch import nn
import torch
import torch.distributions as dist
from torchvision.transforms import Lambda
import math
import time
import gc
from matplotlib.colors import LinearSegmentedColormap
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable
import yaml



class simulate:
    """Hold an environment's sample map and render trajectories on top of it.

    Parameters
    ----------
    cfg : object
        Configuration accessed by attribute. ``render`` reads ``ep_len``,
        ``algorithm`` (concatenated as a str), ``sample_type``, ``sk_num``,
        ``copies`` and ``d``.
    env : object
        Must provide ``gen()`` returning a 3-tuple. Only the first element is
        kept (as ``self.sample``); it is drawn with ``imshow`` as the plot
        background.
    """

    # Directory that receives both the saved tensor and the PNG plots.
    _RESULT_DIR = "main experiment/result"

    def __init__(self, cfg, env):
        # env.gen() returns three values; only the first (the map) is used
        # here, but unpacking still enforces the 3-tuple contract.
        sample, _hilllist, _varlist = env.gen()
        self.sample = sample
        self.cfg = cfg

    def render(self, total_state):
        """Save ``total_state`` and plot its trajectories over ``self.sample``.

        ``total_state`` is reshaped to ``(-1, cfg.ep_len + 1, 2)``: one row
        per trajectory, ``ep_len + 1`` steps each. Channel 0 is drawn as x
        and channel 1 as y.

        The reshaped tensor is written to
        ``<result dir>/<algorithm><sample_type>.pt``. The figure is written to
        ``<result dir>/<d><algorithm><sample_type><suffix>``, where the suffix
        depends on the number of trajectories:

        * ``sk_num * copies`` -> ``color.png``  (hsv colours, default width)
        * ``sk_num``          -> ``point.png``  (tab20 colours, width 3)
        * anything else       -> ``point2.png`` (tab20 colours, width 3)

        The ``sk_num * copies`` case is checked first.
        """
        import os

        cfg = self.cfg
        result_dir = self._RESULT_DIR
        total_state = total_state.reshape(-1, cfg.ep_len + 1, 2)

        # Create the output directory so saving does not fail on a fresh checkout.
        os.makedirs(result_dir, exist_ok=True)
        torch.save(
            total_state,
            os.path.join(result_dir, cfg.algorithm + str(cfg.sample_type) + ".pt"),
        )

        # Map state coordinates to image pixels.
        # NOTE(review): the *5 + 50 transform presumably matches the grid of
        # self.sample (as produced by env.gen()); confirm.
        coords = total_state.detach().cpu().numpy() * 5 + 50
        x, y = coords[:, :, 0], coords[:, :, 1]

        n_traj = len(total_state)
        if n_traj == cfg.sk_num * cfg.copies:
            cmap = plt.get_cmap("hsv")
            # hsv is cyclic (0 and 1 are both red); dividing by n_traj rather
            # than n_traj - 1 keeps the last trajectory's colour distinct
            # from the first's.
            norm = Normalize(vmin=0, vmax=n_traj)
            line_kwargs = {}
            suffix = "color.png"
        else:
            cmap = plt.get_cmap("tab20")
            norm = Normalize(vmin=0, vmax=n_traj - 1)
            line_kwargs = {"linewidth": 3}
            suffix = "point.png" if n_traj == cfg.sk_num else "point2.png"

        fig = plt.figure(figsize=(10, 6))
        try:
            scalar_map = ScalarMappable(norm=norm, cmap=cmap)
            # Each trajectory gets the colour of its own index.
            for i, (xs, ys) in enumerate(zip(x, y)):
                plt.plot(xs, ys, color=scalar_map.to_rgba(i), alpha=0.6, **line_kwargs)

            # Background map: low values white, high values red.
            background_cmap = LinearSegmentedColormap.from_list(
                "white_red", ["white", "red"]
            )
            plt.imshow(self.sample, cmap=background_cmap)
            plt.xticks([])
            plt.yticks([])

            plt.savefig(
                os.path.join(
                    result_dir,
                    str(cfg.d) + cfg.algorithm + str(cfg.sample_type) + suffix,
                )
            )
        finally:
            # Always release the figure, even if plotting/saving raised, so
            # repeated calls do not accumulate open figures.
            plt.close(fig)
